{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "import pandas as pd\n",
    "from sklearn import metrics\n",
    "import transformers\n",
    "import time\n",
    "import torch\n",
    "from tqdm import tqdm\n",
    "from torch.utils.data import Dataset, DataLoader, RandomSampler, SequentialSampler\n",
    "from transformers import BertTokenizer, BertModel, BertConfig\n",
    "\n",
    "# Run on the GPU when CUDA is available, otherwise fall back to the CPU.\n",
    "device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Load the tab-separated training and test CSV files from the local data directory.\n",
    "train_df = pd.read_csv('C:/Users/Administrator/Desktop/Pytorch_BERT/data/train_set.csv', sep='\\t')\n",
    "test_df = pd.read_csv('C:/Users/Administrator/Desktop/Pytorch_BERT/data/test_a.csv', sep='\\t')\n",
    "# Add a constant `label` column (0) to the test frame; CustomDataset reads `.label` from every frame.\n",
    "test_df['label'] = 0"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead th {\n",
       "        text-align: right;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th></th>\n",
       "      <th>label</th>\n",
       "      <th>text</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th>0</th>\n",
       "      <td>2</td>\n",
       "      <td>2967 6758 339 2021 1854 3731 4109 3792 4149 15...</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>1</th>\n",
       "      <td>11</td>\n",
       "      <td>4464 486 6352 5619 2465 4802 1452 3137 5778 54...</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>2</th>\n",
       "      <td>3</td>\n",
       "      <td>7346 4068 5074 3747 5681 6093 1777 2226 7354 6...</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>3</th>\n",
       "      <td>2</td>\n",
       "      <td>7159 948 4866 2109 5520 2490 211 3956 5520 549...</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>4</th>\n",
       "      <td>3</td>\n",
       "      <td>3646 3055 3055 2490 4659 6065 3370 5814 2465 5...</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "</div>"
      ],
      "text/plain": [
       "   label                                               text\n",
       "0      2  2967 6758 339 2021 1854 3731 4109 3792 4149 15...\n",
       "1     11  4464 486 6352 5619 2465 4802 1452 3137 5778 54...\n",
       "2      3  7346 4068 5074 3747 5681 6093 1777 2226 7354 6...\n",
       "3      2  7159 948 4866 2109 5520 2490 211 3956 5520 549...\n",
       "4      3  3646 3055 3055 2490 4659 6065 3370 5814 2465 5..."
      ]
     },
     "execution_count": 4,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Preview the first rows of the training frame.\n",
    "train_df.head()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "{'input_ids': [2, 280, 1106, 1529, 518, 193, 3], 'token_type_ids': [0, 0, 0, 0, 0, 0, 0], 'attention_mask': [1, 1, 1, 1, 1, 1, 1]}"
      ]
     },
     "execution_count": 5,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Load the tokenizer from the local bert-mini directory.\n",
    "tokenizer = BertTokenizer.from_pretrained('C:/Users/Administrator/Desktop/Pytorch_BERT/emb/bert-mini/')\n",
    "# Encode a short sample string to inspect the tokenizer output (the cell displays the returned dict).\n",
    "tokenizer.encode_plus(\"2967 6758 339 2021 1854\",\n",
    "        add_special_tokens=True,\n",
    "        max_length=20,\n",
    "        truncation=True)\n",
    "# token_type_ids: usually the first sentence is marked entirely with 0 and the second sentence entirely with 1.\n",
    "# attention_mask: 0 at padded positions, 1 at non-padded positions."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [],
   "source": [
    "class CustomDataset(Dataset):\n",
    "\n",
    "    def __init__(self, dataframe, tokenizer, max_len):\n",
    "        self.tokenizer = tokenizer\n",
    "        self.data = dataframe\n",
    "        self.comment_text = self.data.text\n",
    "        self.targets = self.data.label\n",
    "        self.max_len = max_len\n",
    "\n",
    "    def __len__(self):\n",
    "        return len(self.comment_text)\n",
    "\n",
    "    def __getitem__(self, index):\n",
    "        comment_text = self.comment_text[index]\n",
    "\n",
    "        inputs = self.tokenizer.encode_plus(\n",
    "            comment_text,\n",
    "            truncation=True,\n",
    "            add_special_tokens=True,\n",
    "            max_length=self.max_len,\n",
    "            pad_to_max_length=True,\n",
    "            return_token_type_ids=True\n",
    "        )\n",
    "        ids = inputs['input_ids']\n",
    "        mask = inputs['attention_mask']\n",
    "        token_type_ids = inputs[\"token_type_ids\"]\n",
    "\n",
    "        return {\n",
    "            'ids': torch.tensor(ids, dtype=torch.long),\n",
    "            'mask': torch.tensor(mask, dtype=torch.long),\n",
    "            'token_type_ids': torch.tensor(token_type_ids, dtype=torch.long),\n",
    "            'targets': torch.tensor(self.targets[index], dtype=torch.float)\n",
    "        }"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "FULL Dataset: (200000, 2)\n",
      "TRAIN Dataset: (160000, 2)\n",
      "VALID Dataset: (40000, 2)\n",
      "TEST Dataset: (50000, 2)\n"
     ]
    }
   ],
   "source": [
    "# Creating the dataset and dataloader for the neural network\n",
    "MAX_LEN = 256  # max_len handed to every CustomDataset below\n",
    "train_size = 0.8  # fraction of train_df sampled for training; the remainder is used for validation\n",
    "# A fixed random_state keeps the train/validation split reproducible.\n",
    "train_dataset = train_df.sample(frac=train_size,random_state=7)\n",
    "# Validation rows are the ones not sampled above; both frames then get a fresh 0..n-1 index.\n",
    "valid_dataset = train_df.drop(train_dataset.index).reset_index(drop=True)\n",
    "train_dataset = train_dataset.reset_index(drop=True)\n",
    "\n",
    "\n",
    "# Report the shape of each frame as a sanity check on the split.\n",
    "print(\"FULL Dataset: {}\".format(train_df.shape))\n",
    "print(\"TRAIN Dataset: {}\".format(train_dataset.shape))\n",
    "print(\"VALID Dataset: {}\".format(valid_dataset.shape))\n",
    "print(\"TEST Dataset: {}\".format(test_df.shape))\n",
    "\n",
    "# Wrap each frame in a CustomDataset sharing the same tokenizer and MAX_LEN.\n",
    "train_set = CustomDataset(train_dataset, tokenizer, MAX_LEN)\n",
    "valid_set = CustomDataset(valid_dataset, tokenizer, MAX_LEN)\n",
    "test_set = CustomDataset(test_df, tokenizer, MAX_LEN)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Batch sizes for the three DataLoaders.\n",
    "TRAIN_BATCH_SIZE = 32\n",
    "VALID_BATCH_SIZE = 16\n",
    "TEST_BATCH_SIZE = 16\n",
    "# Keyword arguments unpacked into each DataLoader below.\n",
    "train_params = {'batch_size': TRAIN_BATCH_SIZE,\n",
    "                'shuffle': True}\n",
    "\n",
    "# NOTE(review): validation batches are shuffled as well; the metrics computed in this\n",
    "# notebook aggregate over all batches, so shuffling here looks unnecessary - confirm.\n",
    "valid_params = {'batch_size': VALID_BATCH_SIZE,\n",
    "                'shuffle': True}\n",
    "\n",
    "# The test loader is not shuffled, so batches follow the row order of test_df.\n",
    "test_params = {'batch_size': TEST_BATCH_SIZE,\n",
    "                'shuffle': False}\n",
    "\n",
    "train_loader = DataLoader(train_set, **train_params)\n",
    "valid_loader = DataLoader(valid_set, **valid_params)\n",
    "test_loader = DataLoader(test_set, **test_params)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "BERTClass(\n",
       "  (l1): BertModel(\n",
       "    (embeddings): BertEmbeddings(\n",
       "      (word_embeddings): Embedding(5981, 256, padding_idx=0)\n",
       "      (position_embeddings): Embedding(256, 256)\n",
       "      (token_type_embeddings): Embedding(2, 256)\n",
       "      (LayerNorm): LayerNorm((256,), eps=1e-12, elementwise_affine=True)\n",
       "      (dropout): Dropout(p=0.1, inplace=False)\n",
       "    )\n",
       "    (encoder): BertEncoder(\n",
       "      (layer): ModuleList(\n",
       "        (0-3): 4 x BertLayer(\n",
       "          (attention): BertAttention(\n",
       "            (self): BertSelfAttention(\n",
       "              (query): Linear(in_features=256, out_features=256, bias=True)\n",
       "              (key): Linear(in_features=256, out_features=256, bias=True)\n",
       "              (value): Linear(in_features=256, out_features=256, bias=True)\n",
       "              (dropout): Dropout(p=0.1, inplace=False)\n",
       "            )\n",
       "            (output): BertSelfOutput(\n",
       "              (dense): Linear(in_features=256, out_features=256, bias=True)\n",
       "              (LayerNorm): LayerNorm((256,), eps=1e-12, elementwise_affine=True)\n",
       "              (dropout): Dropout(p=0.1, inplace=False)\n",
       "            )\n",
       "          )\n",
       "          (intermediate): BertIntermediate(\n",
       "            (dense): Linear(in_features=256, out_features=1024, bias=True)\n",
       "            (intermediate_act_fn): GELUActivation()\n",
       "          )\n",
       "          (output): BertOutput(\n",
       "            (dense): Linear(in_features=1024, out_features=256, bias=True)\n",
       "            (LayerNorm): LayerNorm((256,), eps=1e-12, elementwise_affine=True)\n",
       "            (dropout): Dropout(p=0.1, inplace=False)\n",
       "          )\n",
       "        )\n",
       "      )\n",
       "    )\n",
       "    (pooler): BertPooler(\n",
       "      (dense): Linear(in_features=256, out_features=256, bias=True)\n",
       "      (activation): Tanh()\n",
       "    )\n",
       "  )\n",
       "  (bilstm1): LSTM(512, 64, bidirectional=True)\n",
       "  (l2): Linear(in_features=128, out_features=64, bias=True)\n",
       "  (a1): ReLU()\n",
       "  (l3): Dropout(p=0.3, inplace=False)\n",
       "  (l4): Linear(in_features=64, out_features=14, bias=True)\n",
       ")"
      ]
     },
     "execution_count": 9,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Custom classifier: the pretrained BERT encoder followed by a BiLSTM, a dense layer,\n",
    "# dropout and a final linear layer that produces the class logits.\n",
    "\n",
    "class BERTClass(torch.nn.Module):\n",
    "    \"\"\"BERT encoder whose last two first-token hidden states feed a BiLSTM + MLP head.\n",
    "\n",
    "    The encoder config and weights are loaded from the local bert-mini files.\n",
    "    `forward` returns raw logits of shape [batch, 14] (no softmax applied).\n",
    "    \"\"\"\n",
    "\n",
    "    def __init__(self):\n",
    "        super().__init__()\n",
    "        self.config = BertConfig.from_pretrained('C:/Users/Administrator/Desktop/Pytorch_BERT/emb/bert-mini/bert_config.json', output_hidden_states=True)\n",
    "        self.l1 = BertModel.from_pretrained('C:/Users/Administrator/Desktop/Pytorch_BERT/emb/bert-mini/pytorch_model.bin', config=self.config)\n",
    "        # Derive the head's input width from the encoder config instead of\n",
    "        # hard-coding the hidden size, so another checkpoint width also works.\n",
    "        hidden_size = self.config.hidden_size\n",
    "        lstm_hidden = 64\n",
    "        # The LSTM input is the concatenation of two hidden-state vectors.\n",
    "        self.bilstm1 = torch.nn.LSTM(2 * hidden_size, lstm_hidden, 1, bidirectional=True)\n",
    "        # A bidirectional LSTM emits 2 * lstm_hidden features per step.\n",
    "        self.l2 = torch.nn.Linear(2 * lstm_hidden, 64)\n",
    "        self.a1 = torch.nn.ReLU()\n",
    "        self.l3 = torch.nn.Dropout(0.3)\n",
    "        self.l4 = torch.nn.Linear(64, 14)\n",
    "\n",
    "    def forward(self, ids, mask, token_type_ids):\n",
    "        \"\"\"Return logits of shape [batch, 14] for a batch of token ids, attention mask and token type ids.\"\"\"\n",
    "        # With output_hidden_states=True and return_dict=False the encoder returns\n",
    "        # (sequence_output, pooler_output, hidden_states).\n",
    "        _, _, hidden_states = self.l1(ids, attention_mask=mask, token_type_ids=token_type_ids, return_dict=False)\n",
    "\n",
    "        # First-token vector (the [CLS] position when special tokens are added)\n",
    "        # of the last two hidden states, each [bs, hidden_size].\n",
    "        cls_last = hidden_states[-1][:, 0]\n",
    "        cls_prev = hidden_states[-2][:, 0]\n",
    "        # [1, bs, 2 * hidden_size]: the LSTM is not batch_first, so the leading\n",
    "        # dimension is a sequence of length 1.\n",
    "        concat_hidden = torch.cat((cls_last, cls_prev), dim=1).unsqueeze(0)\n",
    "        x, _ = self.bilstm1(concat_hidden)\n",
    "        # Drop the length-1 sequence dimension -> [bs, 2 * lstm_hidden].\n",
    "        x = self.l2(x.squeeze(0))\n",
    "        x = self.a1(x)\n",
    "        x = self.l3(x)\n",
    "        output = self.l4(x)\n",
    "        return output\n",
    "\n",
    "net = BERTClass()\n",
    "net.to(device)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Hyperparameter settings\n",
    "lr, num_epochs = 1e-5, 30\n",
    "criterion = torch.nn.CrossEntropyLoss()  # loss function\n",
    "optimizer = torch.optim.Adam(net.parameters(), lr=lr)  # optimizer"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "metadata": {},
   "outputs": [],
   "source": [
    "def evaluate_accuracy(data_iter, net, device=torch.device('cpu')):\n",
    "    \"\"\"Evaluate accuracy and macro-F1 of a model on the given data set.\n",
    "\n",
    "    Args:\n",
    "        data_iter: iterable of batches, each a dict with `ids`, `mask`,\n",
    "            `token_type_ids` and `targets` tensors.\n",
    "        net: model called as `net(ids, mask, token_type_ids)` returning per-class scores.\n",
    "        device: device the batch tensors are moved to before the forward pass.\n",
    "\n",
    "    Returns:\n",
    "        Tuple `(accuracy, macro_f1)`.\n",
    "\n",
    "    Note:\n",
    "        The model is left in eval mode when this function returns.\n",
    "    \"\"\"\n",
    "    correct, n = 0, 0\n",
    "    y_pred_, y_true_ = [], []\n",
    "    # Switch to inference behaviour once instead of on every batch.\n",
    "    net.eval()\n",
    "    # The forward pass itself must run under no_grad; otherwise autograd records a\n",
    "    # graph for every validation batch, wasting memory and time.\n",
    "    with torch.no_grad():\n",
    "        for data in tqdm(data_iter):\n",
    "            # If device is the GPU, copy the data to the GPU.\n",
    "            ids = data['ids'].to(device, dtype=torch.long)\n",
    "            mask = data['mask'].to(device, dtype=torch.long)\n",
    "            token_type_ids = data['token_type_ids'].to(device, dtype=torch.long)\n",
    "            targets = data['targets'].to(device, dtype=torch.long)\n",
    "            # e.g. scores [[0.2, 0.4, 0.5, 0.6, 0.8], [0.1, 0.2, 0.4, 0.3, 0.1]] -> predictions [4, 2]\n",
    "            preds = torch.argmax(net(ids, mask, token_type_ids), dim=1)\n",
    "            correct += (preds == targets).sum().item()\n",
    "            y_pred_.extend(preds.cpu().tolist())\n",
    "            y_true_.extend(targets.cpu().tolist())\n",
    "            n += targets.shape[0]\n",
    "    valid_f1 = metrics.f1_score(y_true_, y_pred_, average='macro')\n",
    "    return correct / n, valid_f1"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "metadata": {},
   "outputs": [],
   "source": [
    "def train(epoch,train_iter, test_iter, criterion, num_epochs, optimizer, device):\n",
    "    \"\"\"Train the model, validate after every epoch and save the best checkpoint.\n",
    "\n",
    "    Args:\n",
    "        epoch: historical name kept for compatibility. It was never used as an\n",
    "            epoch number (the loop variable shadowed it). If a `torch.nn.Module`\n",
    "            is passed here it is the model that gets trained; otherwise the\n",
    "            module-level `net` is used.\n",
    "        train_iter: iterable of training batches (dicts with `ids`, `mask`,\n",
    "            `token_type_ids`, `targets`).\n",
    "        test_iter: iterable of validation batches, evaluated after each epoch.\n",
    "        criterion: loss called as `criterion(logits, long_targets)`.\n",
    "        num_epochs: number of passes over `train_iter`.\n",
    "        optimizer: optimizer stepped once per batch; a StepLR schedule is attached to it.\n",
    "        device: device the model and batches are moved to.\n",
    "\n",
    "    Side effects:\n",
    "        Prints one metrics line per epoch and writes the state dict to\n",
    "        `model/best.pth` whenever the validation macro-F1 improves.\n",
    "    \"\"\"\n",
    "    # Honour a model passed in the first slot; previously it was silently ignored\n",
    "    # in favour of the global `net`.\n",
    "    model = epoch if isinstance(epoch, torch.nn.Module) else net\n",
    "    print('training on', device)\n",
    "    model.to(device)\n",
    "    best_test_f1 = 0\n",
    "    scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=5, gamma=0.5)  # learning-rate decay: halve every 5 epochs\n",
    "#     scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=5, eta_min=2e-06)  # cosine annealing\n",
    "    # A mean-reduced criterion returns the batch average, so it has to be weighted\n",
    "    # by the batch size before being averaged over all samples.\n",
    "    mean_reduced = getattr(criterion, 'reduction', 'mean') == 'mean'\n",
    "    for epoch_idx in range(num_epochs):\n",
    "        train_l_sum, train_correct, n = 0.0, 0, 0\n",
    "        start = time.time()\n",
    "        y_pred, y_true = [], []\n",
    "        # Validation leaves the model in eval mode, so re-enable training mode each epoch.\n",
    "        model.train()\n",
    "        for data in tqdm(train_iter):\n",
    "            optimizer.zero_grad()\n",
    "            ids = data['ids'].to(device, dtype=torch.long)\n",
    "            mask = data['mask'].to(device, dtype=torch.long)\n",
    "            token_type_ids = data['token_type_ids'].to(device, dtype=torch.long)\n",
    "            targets = data['targets'].to(device, dtype=torch.long)\n",
    "            y_hat = model(ids, mask, token_type_ids)\n",
    "            loss = criterion(y_hat, targets)\n",
    "            loss.backward()\n",
    "            optimizer.step()\n",
    "\n",
    "            with torch.no_grad():\n",
    "                batch_size = targets.shape[0]\n",
    "                preds = torch.argmax(y_hat, dim=1)\n",
    "                train_l_sum += loss.item() * (batch_size if mean_reduced else 1)\n",
    "                train_correct += (preds == targets).sum().item()\n",
    "                y_pred.extend(preds.cpu().tolist())\n",
    "                y_true.extend(targets.cpu().tolist())\n",
    "                n += batch_size\n",
    "        valid_acc, valid_f1 = evaluate_accuracy(test_iter, model, device)\n",
    "        denom = max(n, 1)  # avoid dividing by zero when train_iter is empty\n",
    "        train_acc = train_correct / denom\n",
    "        train_f1 = metrics.f1_score(y_true, y_pred, average='macro')\n",
    "        print('epoch %d, loss %.4f, train acc %.3f, valid acc %.3f, '\n",
    "              'train f1 %.3f, valid f1 %.3f, time %.1f sec'\n",
    "              % (epoch_idx + 1, train_l_sum / denom, train_acc, valid_acc,\n",
    "                 train_f1, valid_f1, time.time() - start))\n",
    "        if valid_f1 > best_test_f1:\n",
    "            print('find best! save at model/best.pth')\n",
    "            best_test_f1 = valid_f1\n",
    "            torch.save(model.state_dict(), 'C:/Users/Administrator/Desktop/Pytorch_BERT/model/best.pth')\n",
    "        scheduler.step()  # update the learning rate"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 29,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "training on cpu\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|████████████████████████████████████████████████████████████████████████████| 5000/5000 [3:29:33<00:00,  2.51s/it]\n",
      "100%|██████████████████████████████████████████████████████████████████████████████| 2500/2500 [19:47<00:00,  2.11it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "epoch 1, loss 0.0317, train acc 0.794, valid acc 0.901, train f1 0.562, valid f1 0.749, time 13760.9 sec\n",
      "find best! save at model/best.pth\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "  0%|                                                                                         | 0/5000 [00:00<?, ?it/s]D:\\anaconda3\\envs\\tf-bert\\lib\\site-packages\\transformers\\tokenization_utils_base.py:2645: FutureWarning: The `pad_to_max_length` argument is deprecated and will be removed in a future version, use `padding=True` or `padding='longest'` to pad to the longest sequence in the batch, or use `padding='max_length'` to pad to a max length. In this case, you can give a specific length with `max_length` (e.g. `max_length=45`) or leave max_length to None to pad to the maximal input size of the model (e.g. 512 for Bert).\n",
      "  warnings.warn(\n",
      "100%|████████████████████████████████████████████████████████████████████████████| 5000/5000 [3:18:02<00:00,  2.38s/it]\n",
      "100%|██████████████████████████████████████████████████████████████████████████████| 2500/2500 [40:37<00:00,  1.03it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "epoch 2, loss 0.0128, train acc 0.904, valid acc 0.925, train f1 0.759, valid f1 0.827, time 14320.4 sec\n",
      "find best! save at model/best.pth\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "  0%|                                                                                         | 0/5000 [00:00<?, ?it/s]D:\\anaconda3\\envs\\tf-bert\\lib\\site-packages\\transformers\\tokenization_utils_base.py:2645: FutureWarning: The `pad_to_max_length` argument is deprecated and will be removed in a future version, use `padding=True` or `padding='longest'` to pad to the longest sequence in the batch, or use `padding='max_length'` to pad to a max length. In this case, you can give a specific length with `max_length` (e.g. `max_length=45`) or leave max_length to None to pad to the maximal input size of the model (e.g. 512 for Bert).\n",
      "  warnings.warn(\n",
      " 18%|█████████████▋                                                               | 890/5000 [38:25<2:57:25,  2.59s/it]\n"
     ]
    },
    {
     "ename": "KeyboardInterrupt",
     "evalue": "",
     "output_type": "error",
     "traceback": [
      "\u001b[1;31m---------------------------------------------------------------------------\u001b[0m",
      "\u001b[1;31mKeyboardInterrupt\u001b[0m                         Traceback (most recent call last)",
      "Cell \u001b[1;32mIn[29], line 1\u001b[0m\n\u001b[1;32m----> 1\u001b[0m \u001b[43mtrain\u001b[49m\u001b[43m(\u001b[49m\u001b[43mnet\u001b[49m\u001b[43m,\u001b[49m\u001b[43mtrain_loader\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mvalid_loader\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mcriterion\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mnum_epochs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43moptimizer\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mdevice\u001b[49m\u001b[43m)\u001b[49m\n",
      "Cell \u001b[1;32mIn[28], line 12\u001b[0m, in \u001b[0;36mtrain\u001b[1;34m(epoch, train_iter, test_iter, criterion, num_epochs, optimizer, device)\u001b[0m\n\u001b[0;32m     10\u001b[0m n, start \u001b[38;5;241m=\u001b[39m \u001b[38;5;241m0\u001b[39m, time\u001b[38;5;241m.\u001b[39mtime()\n\u001b[0;32m     11\u001b[0m y_pred, y_true \u001b[38;5;241m=\u001b[39m [], []\n\u001b[1;32m---> 12\u001b[0m \u001b[38;5;28;01mfor\u001b[39;00m data \u001b[38;5;129;01min\u001b[39;00m tqdm(train_iter):\n\u001b[0;32m     13\u001b[0m     net\u001b[38;5;241m.\u001b[39mtrain()\n\u001b[0;32m     14\u001b[0m     optimizer\u001b[38;5;241m.\u001b[39mzero_grad()\n",
      "File \u001b[1;32mD:\\anaconda3\\envs\\tf-bert\\lib\\site-packages\\tqdm\\std.py:1181\u001b[0m, in \u001b[0;36mtqdm.__iter__\u001b[1;34m(self)\u001b[0m\n\u001b[0;32m   1178\u001b[0m time \u001b[38;5;241m=\u001b[39m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_time\n\u001b[0;32m   1180\u001b[0m \u001b[38;5;28;01mtry\u001b[39;00m:\n\u001b[1;32m-> 1181\u001b[0m     \u001b[38;5;28;01mfor\u001b[39;00m obj \u001b[38;5;129;01min\u001b[39;00m iterable:\n\u001b[0;32m   1182\u001b[0m         \u001b[38;5;28;01myield\u001b[39;00m obj\n\u001b[0;32m   1183\u001b[0m         \u001b[38;5;66;03m# Update and possibly print the progressbar.\u001b[39;00m\n\u001b[0;32m   1184\u001b[0m         \u001b[38;5;66;03m# Note: does not call self.update(1) for speed optimisation.\u001b[39;00m\n",
      "File \u001b[1;32mD:\\anaconda3\\envs\\tf-bert\\lib\\site-packages\\torch\\utils\\data\\dataloader.py:701\u001b[0m, in \u001b[0;36m_BaseDataLoaderIter.__next__\u001b[1;34m(self)\u001b[0m\n\u001b[0;32m    698\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_sampler_iter \u001b[38;5;129;01mis\u001b[39;00m \u001b[38;5;28;01mNone\u001b[39;00m:\n\u001b[0;32m    699\u001b[0m     \u001b[38;5;66;03m# TODO(https://github.com/pytorch/pytorch/issues/76750)\u001b[39;00m\n\u001b[0;32m    700\u001b[0m     \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_reset()  \u001b[38;5;66;03m# type: ignore[call-arg]\u001b[39;00m\n\u001b[1;32m--> 701\u001b[0m data \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43m_next_data\u001b[49m\u001b[43m(\u001b[49m\u001b[43m)\u001b[49m\n\u001b[0;32m    702\u001b[0m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_num_yielded \u001b[38;5;241m+\u001b[39m\u001b[38;5;241m=\u001b[39m \u001b[38;5;241m1\u001b[39m\n\u001b[0;32m    703\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m (\n\u001b[0;32m    704\u001b[0m     \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_dataset_kind \u001b[38;5;241m==\u001b[39m _DatasetKind\u001b[38;5;241m.\u001b[39mIterable\n\u001b[0;32m    705\u001b[0m     \u001b[38;5;129;01mand\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_IterableDataset_len_called \u001b[38;5;129;01mis\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m \u001b[38;5;28;01mNone\u001b[39;00m\n\u001b[0;32m    706\u001b[0m     \u001b[38;5;129;01mand\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_num_yielded \u001b[38;5;241m>\u001b[39m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_IterableDataset_len_called\n\u001b[0;32m    707\u001b[0m ):\n",
      "File \u001b[1;32mD:\\anaconda3\\envs\\tf-bert\\lib\\site-packages\\torch\\utils\\data\\dataloader.py:757\u001b[0m, in \u001b[0;36m_SingleProcessDataLoaderIter._next_data\u001b[1;34m(self)\u001b[0m\n\u001b[0;32m    755\u001b[0m \u001b[38;5;28;01mdef\u001b[39;00m\u001b[38;5;250m \u001b[39m\u001b[38;5;21m_next_data\u001b[39m(\u001b[38;5;28mself\u001b[39m):\n\u001b[0;32m    756\u001b[0m     index \u001b[38;5;241m=\u001b[39m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_next_index()  \u001b[38;5;66;03m# may raise StopIteration\u001b[39;00m\n\u001b[1;32m--> 757\u001b[0m     data \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43m_dataset_fetcher\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mfetch\u001b[49m\u001b[43m(\u001b[49m\u001b[43mindex\u001b[49m\u001b[43m)\u001b[49m  \u001b[38;5;66;03m# may raise StopIteration\u001b[39;00m\n\u001b[0;32m    758\u001b[0m     \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_pin_memory:\n\u001b[0;32m    759\u001b[0m         data \u001b[38;5;241m=\u001b[39m _utils\u001b[38;5;241m.\u001b[39mpin_memory\u001b[38;5;241m.\u001b[39mpin_memory(data, \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_pin_memory_device)\n",
      "File \u001b[1;32mD:\\anaconda3\\envs\\tf-bert\\lib\\site-packages\\torch\\utils\\data\\_utils\\fetch.py:52\u001b[0m, in \u001b[0;36m_MapDatasetFetcher.fetch\u001b[1;34m(self, possibly_batched_index)\u001b[0m\n\u001b[0;32m     50\u001b[0m         data \u001b[38;5;241m=\u001b[39m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mdataset\u001b[38;5;241m.\u001b[39m__getitems__(possibly_batched_index)\n\u001b[0;32m     51\u001b[0m     \u001b[38;5;28;01melse\u001b[39;00m:\n\u001b[1;32m---> 52\u001b[0m         data \u001b[38;5;241m=\u001b[39m [\u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mdataset[idx] \u001b[38;5;28;01mfor\u001b[39;00m idx \u001b[38;5;129;01min\u001b[39;00m possibly_batched_index]\n\u001b[0;32m     53\u001b[0m \u001b[38;5;28;01melse\u001b[39;00m:\n\u001b[0;32m     54\u001b[0m     data \u001b[38;5;241m=\u001b[39m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mdataset[possibly_batched_index]\n",
      "File \u001b[1;32mD:\\anaconda3\\envs\\tf-bert\\lib\\site-packages\\torch\\utils\\data\\_utils\\fetch.py:52\u001b[0m, in \u001b[0;36m<listcomp>\u001b[1;34m(.0)\u001b[0m\n\u001b[0;32m     50\u001b[0m         data \u001b[38;5;241m=\u001b[39m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mdataset\u001b[38;5;241m.\u001b[39m__getitems__(possibly_batched_index)\n\u001b[0;32m     51\u001b[0m     \u001b[38;5;28;01melse\u001b[39;00m:\n\u001b[1;32m---> 52\u001b[0m         data \u001b[38;5;241m=\u001b[39m [\u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mdataset\u001b[49m\u001b[43m[\u001b[49m\u001b[43midx\u001b[49m\u001b[43m]\u001b[49m \u001b[38;5;28;01mfor\u001b[39;00m idx \u001b[38;5;129;01min\u001b[39;00m possibly_batched_index]\n\u001b[0;32m     53\u001b[0m \u001b[38;5;28;01melse\u001b[39;00m:\n\u001b[0;32m     54\u001b[0m     data \u001b[38;5;241m=\u001b[39m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mdataset[possibly_batched_index]\n",
      "Cell \u001b[1;32mIn[10], line 16\u001b[0m, in \u001b[0;36mCustomDataset.__getitem__\u001b[1;34m(self, index)\u001b[0m\n\u001b[0;32m     13\u001b[0m \u001b[38;5;28;01mdef\u001b[39;00m\u001b[38;5;250m \u001b[39m\u001b[38;5;21m__getitem__\u001b[39m(\u001b[38;5;28mself\u001b[39m, index):\n\u001b[0;32m     14\u001b[0m     comment_text \u001b[38;5;241m=\u001b[39m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mcomment_text[index]\n\u001b[1;32m---> 16\u001b[0m     inputs \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mtokenizer\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mencode_plus\u001b[49m\u001b[43m(\u001b[49m\n\u001b[0;32m     17\u001b[0m \u001b[43m        \u001b[49m\u001b[43mcomment_text\u001b[49m\u001b[43m,\u001b[49m\n\u001b[0;32m     18\u001b[0m \u001b[43m        \u001b[49m\u001b[43mtruncation\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[38;5;28;43;01mTrue\u001b[39;49;00m\u001b[43m,\u001b[49m\n\u001b[0;32m     19\u001b[0m \u001b[43m        \u001b[49m\u001b[43madd_special_tokens\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[38;5;28;43;01mTrue\u001b[39;49;00m\u001b[43m,\u001b[49m\n\u001b[0;32m     20\u001b[0m \u001b[43m        \u001b[49m\u001b[43mmax_length\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mmax_len\u001b[49m\u001b[43m,\u001b[49m\n\u001b[0;32m     21\u001b[0m \u001b[43m        \u001b[49m\u001b[43mpad_to_max_length\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[38;5;28;43;01mTrue\u001b[39;49;00m\u001b[43m,\u001b[49m\n\u001b[0;32m     22\u001b[0m \u001b[43m        \u001b[49m\u001b[43mreturn_token_type_ids\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[38;5;28;43;01mTrue\u001b[39;49;00m\n\u001b[0;32m     23\u001b[0m \u001b[43m    \u001b[49m\u001b[43m)\u001b[49m\n\u001b[0;32m     24\u001b[0m     ids \u001b[38;5;241m=\u001b[39m 
inputs[\u001b[38;5;124m'\u001b[39m\u001b[38;5;124minput_ids\u001b[39m\u001b[38;5;124m'\u001b[39m]\n\u001b[0;32m     25\u001b[0m     mask \u001b[38;5;241m=\u001b[39m inputs[\u001b[38;5;124m'\u001b[39m\u001b[38;5;124mattention_mask\u001b[39m\u001b[38;5;124m'\u001b[39m]\n",
      "File \u001b[1;32mD:\\anaconda3\\envs\\tf-bert\\lib\\site-packages\\transformers\\tokenization_utils_base.py:3008\u001b[0m, in \u001b[0;36mPreTrainedTokenizerBase.encode_plus\u001b[1;34m(self, text, text_pair, add_special_tokens, padding, truncation, max_length, stride, is_split_into_words, pad_to_multiple_of, return_tensors, return_token_type_ids, return_attention_mask, return_overflowing_tokens, return_special_tokens_mask, return_offsets_mapping, return_length, verbose, **kwargs)\u001b[0m\n\u001b[0;32m   2998\u001b[0m \u001b[38;5;66;03m# Backward compatibility for 'truncation_strategy', 'pad_to_max_length'\u001b[39;00m\n\u001b[0;32m   2999\u001b[0m padding_strategy, truncation_strategy, max_length, kwargs \u001b[38;5;241m=\u001b[39m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_get_padding_truncation_strategies(\n\u001b[0;32m   3000\u001b[0m     padding\u001b[38;5;241m=\u001b[39mpadding,\n\u001b[0;32m   3001\u001b[0m     truncation\u001b[38;5;241m=\u001b[39mtruncation,\n\u001b[1;32m   (...)\u001b[0m\n\u001b[0;32m   3005\u001b[0m     \u001b[38;5;241m*\u001b[39m\u001b[38;5;241m*\u001b[39mkwargs,\n\u001b[0;32m   3006\u001b[0m )\n\u001b[1;32m-> 3008\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_encode_plus(\n\u001b[0;32m   3009\u001b[0m     text\u001b[38;5;241m=\u001b[39mtext,\n\u001b[0;32m   3010\u001b[0m     text_pair\u001b[38;5;241m=\u001b[39mtext_pair,\n\u001b[0;32m   3011\u001b[0m     add_special_tokens\u001b[38;5;241m=\u001b[39madd_special_tokens,\n\u001b[0;32m   3012\u001b[0m     padding_strategy\u001b[38;5;241m=\u001b[39mpadding_strategy,\n\u001b[0;32m   3013\u001b[0m     truncation_strategy\u001b[38;5;241m=\u001b[39mtruncation_strategy,\n\u001b[0;32m   3014\u001b[0m     max_length\u001b[38;5;241m=\u001b[39mmax_length,\n\u001b[0;32m   3015\u001b[0m     stride\u001b[38;5;241m=\u001b[39mstride,\n\u001b[0;32m   3016\u001b[0m     
is_split_into_words\u001b[38;5;241m=\u001b[39mis_split_into_words,\n\u001b[0;32m   3017\u001b[0m     pad_to_multiple_of\u001b[38;5;241m=\u001b[39mpad_to_multiple_of,\n\u001b[0;32m   3018\u001b[0m     return_tensors\u001b[38;5;241m=\u001b[39mreturn_tensors,\n\u001b[0;32m   3019\u001b[0m     return_token_type_ids\u001b[38;5;241m=\u001b[39mreturn_token_type_ids,\n\u001b[0;32m   3020\u001b[0m     return_attention_mask\u001b[38;5;241m=\u001b[39mreturn_attention_mask,\n\u001b[0;32m   3021\u001b[0m     return_overflowing_tokens\u001b[38;5;241m=\u001b[39mreturn_overflowing_tokens,\n\u001b[0;32m   3022\u001b[0m     return_special_tokens_mask\u001b[38;5;241m=\u001b[39mreturn_special_tokens_mask,\n\u001b[0;32m   3023\u001b[0m     return_offsets_mapping\u001b[38;5;241m=\u001b[39mreturn_offsets_mapping,\n\u001b[0;32m   3024\u001b[0m     return_length\u001b[38;5;241m=\u001b[39mreturn_length,\n\u001b[0;32m   3025\u001b[0m     verbose\u001b[38;5;241m=\u001b[39mverbose,\n\u001b[0;32m   3026\u001b[0m     \u001b[38;5;241m*\u001b[39m\u001b[38;5;241m*\u001b[39mkwargs,\n\u001b[0;32m   3027\u001b[0m )\n",
      "File \u001b[1;32mD:\\anaconda3\\envs\\tf-bert\\lib\\site-packages\\transformers\\tokenization_utils.py:719\u001b[0m, in \u001b[0;36mPreTrainedTokenizer._encode_plus\u001b[1;34m(self, text, text_pair, add_special_tokens, padding_strategy, truncation_strategy, max_length, stride, is_split_into_words, pad_to_multiple_of, return_tensors, return_token_type_ids, return_attention_mask, return_overflowing_tokens, return_special_tokens_mask, return_offsets_mapping, return_length, verbose, **kwargs)\u001b[0m\n\u001b[0;32m    710\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m return_offsets_mapping:\n\u001b[0;32m    711\u001b[0m     \u001b[38;5;28;01mraise\u001b[39;00m \u001b[38;5;167;01mNotImplementedError\u001b[39;00m(\n\u001b[0;32m    712\u001b[0m         \u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mreturn_offset_mapping is not available when using Python tokenizers. \u001b[39m\u001b[38;5;124m\"\u001b[39m\n\u001b[0;32m    713\u001b[0m         \u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mTo use this feature, change your tokenizer to one deriving from \u001b[39m\u001b[38;5;124m\"\u001b[39m\n\u001b[1;32m   (...)\u001b[0m\n\u001b[0;32m    716\u001b[0m         \u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mhttps://github.com/huggingface/transformers/pull/2674\u001b[39m\u001b[38;5;124m\"\u001b[39m\n\u001b[0;32m    717\u001b[0m     )\n\u001b[1;32m--> 719\u001b[0m first_ids \u001b[38;5;241m=\u001b[39m \u001b[43mget_input_ids\u001b[49m\u001b[43m(\u001b[49m\u001b[43mtext\u001b[49m\u001b[43m)\u001b[49m\n\u001b[0;32m    720\u001b[0m second_ids \u001b[38;5;241m=\u001b[39m get_input_ids(text_pair) \u001b[38;5;28;01mif\u001b[39;00m text_pair \u001b[38;5;129;01mis\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m \u001b[38;5;28;01mNone\u001b[39;00m \u001b[38;5;28;01melse\u001b[39;00m \u001b[38;5;28;01mNone\u001b[39;00m\n\u001b[0;32m    722\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mprepare_for_model(\n\u001b[0;32m    723\u001b[0m    
 first_ids,\n\u001b[0;32m    724\u001b[0m     pair_ids\u001b[38;5;241m=\u001b[39msecond_ids,\n\u001b[1;32m   (...)\u001b[0m\n\u001b[0;32m    738\u001b[0m     verbose\u001b[38;5;241m=\u001b[39mverbose,\n\u001b[0;32m    739\u001b[0m )\n",
      "File \u001b[1;32mD:\\anaconda3\\envs\\tf-bert\\lib\\site-packages\\transformers\\tokenization_utils.py:687\u001b[0m, in \u001b[0;36mPreTrainedTokenizer._encode_plus.<locals>.get_input_ids\u001b[1;34m(text)\u001b[0m\n\u001b[0;32m    685\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;28misinstance\u001b[39m(text, \u001b[38;5;28mstr\u001b[39m):\n\u001b[0;32m    686\u001b[0m     tokens \u001b[38;5;241m=\u001b[39m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mtokenize(text, \u001b[38;5;241m*\u001b[39m\u001b[38;5;241m*\u001b[39mkwargs)\n\u001b[1;32m--> 687\u001b[0m     \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mconvert_tokens_to_ids\u001b[49m\u001b[43m(\u001b[49m\u001b[43mtokens\u001b[49m\u001b[43m)\u001b[49m\n\u001b[0;32m    688\u001b[0m \u001b[38;5;28;01melif\u001b[39;00m \u001b[38;5;28misinstance\u001b[39m(text, (\u001b[38;5;28mlist\u001b[39m, \u001b[38;5;28mtuple\u001b[39m)) \u001b[38;5;129;01mand\u001b[39;00m \u001b[38;5;28mlen\u001b[39m(text) \u001b[38;5;241m>\u001b[39m \u001b[38;5;241m0\u001b[39m \u001b[38;5;129;01mand\u001b[39;00m \u001b[38;5;28misinstance\u001b[39m(text[\u001b[38;5;241m0\u001b[39m], \u001b[38;5;28mstr\u001b[39m):\n\u001b[0;32m    689\u001b[0m     \u001b[38;5;28;01mif\u001b[39;00m is_split_into_words:\n",
      "File \u001b[1;32mD:\\anaconda3\\envs\\tf-bert\\lib\\site-packages\\transformers\\tokenization_utils.py:649\u001b[0m, in \u001b[0;36mPreTrainedTokenizer.convert_tokens_to_ids\u001b[1;34m(self, tokens)\u001b[0m\n\u001b[0;32m    647\u001b[0m ids \u001b[38;5;241m=\u001b[39m []\n\u001b[0;32m    648\u001b[0m \u001b[38;5;28;01mfor\u001b[39;00m token \u001b[38;5;129;01min\u001b[39;00m tokens:\n\u001b[1;32m--> 649\u001b[0m     ids\u001b[38;5;241m.\u001b[39mappend(\u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_convert_token_to_id_with_added_voc(token))\n\u001b[0;32m    650\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m ids\n",
      "\u001b[1;31mKeyboardInterrupt\u001b[0m: "
     ]
    }
   ],
   "source": [
    "# Start training. `train` and the arguments passed here are defined outside this cell.\n",
    "# NOTE(review): the stored output of this cell ends in a KeyboardInterrupt traceback,\n",
    "# i.e. the saved run was interrupted before it completed.\n",
    "train(net,train_loader, valid_loader, criterion, num_epochs, optimizer, device)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "C:\\Users\\Administrator\\AppData\\Local\\Temp\\ipykernel_12920\\2473937819.py:5: FutureWarning: You are using `torch.load` with `weights_only=False` (the current default value), which uses the default pickle module implicitly. It is possible to construct malicious pickle data which will execute arbitrary code during unpickling (See https://github.com/pytorch/pytorch/blob/main/SECURITY.md#untrusted-models for more details). In a future release, the default value for `weights_only` will be flipped to `True`. This limits the functions that could be executed during unpickling. Arbitrary objects will no longer be allowed to be loaded via this mode unless they are explicitly allowlisted by the user via `torch.serialization.add_safe_globals`. We recommend you start setting `weights_only=True` for any use case where you don't have full control of the loaded file. Please open an issue on GitHub for any issues related to this experimental feature.\n",
      "  net.load_state_dict(torch.load('model/best.pth'))\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "加载最优模型\n",
      "inference测试集\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "  0%|                                                                                         | 0/3125 [00:00<?, ?it/s]D:\\anaconda3\\envs\\tf-bert\\lib\\site-packages\\transformers\\tokenization_utils_base.py:2645: FutureWarning: The `pad_to_max_length` argument is deprecated and will be removed in a future version, use `padding=True` or `padding='longest'` to pad to the longest sequence in the batch, or use `padding='max_length'` to pad to a max length. In this case, you can give a specific length with `max_length` (e.g. `max_length=45`) or leave max_length to None to pad to the maximal input size of the model (e.g. 512 for Bert).\n",
      "  warnings.warn(\n",
      "100%|██████████████████████████████████████████████████████████████████████████████| 3125/3125 [24:14<00:00,  2.15it/s]\n"
     ]
    }
   ],
   "source": [
    "def model_predict(net, test_iter, model_path='model/best.pth'):\n",
    "    \"\"\"Load a saved checkpoint into `net` and predict a label for every test sample.\n",
    "\n",
    "    Args:\n",
    "        net: Model whose parameters match the checkpoint's state dict. It is\n",
    "            called as `net(ids, mask, token_type_ids)`.\n",
    "        test_iter: Iterable of batches; each batch is indexed with the keys\n",
    "            'ids', 'mask' and 'token_type_ids', whose values are tensors.\n",
    "        model_path: Path of the state-dict checkpoint to load. Defaults to\n",
    "            'model/best.pth'.\n",
    "\n",
    "    Returns:\n",
    "        list: One predicted class index per sample (argmax over dim 1 of the\n",
    "        model output), in iteration order, as plain Python ints.\n",
    "\n",
    "    Side effects:\n",
    "        Overwrites the parameters of `net`, moves it to the notebook-level\n",
    "        `device` and leaves it in evaluation mode.\n",
    "    \"\"\"\n",
    "    print('加载最优模型')\n",
    "    # map_location lets a checkpoint saved on GPU load on a CPU-only machine.\n",
    "    # weights_only=True restricts unpickling to tensors and plain containers,\n",
    "    # which avoids arbitrary code execution and the torch.load FutureWarning.\n",
    "    state_dict = torch.load(model_path, map_location=device, weights_only=True)\n",
    "    net.load_state_dict(state_dict)\n",
    "    net = net.to(device)\n",
    "    # Evaluation mode: dropout is disabled and batch-norm uses running stats,\n",
    "    # so predictions are deterministic.\n",
    "    net.eval()\n",
    "    print('inference测试集')\n",
    "    preds_list = []\n",
    "    with torch.no_grad():  # no autograd bookkeeping needed for inference\n",
    "        for data in tqdm(test_iter):\n",
    "            ids = data['ids'].to(device, dtype=torch.long)\n",
    "            mask = data['mask'].to(device, dtype=torch.long)\n",
    "            token_type_ids = data['token_type_ids'].to(device, dtype=torch.long)\n",
    "            outputs = net(ids, mask, token_type_ids)\n",
    "            preds_list.extend(outputs.argmax(dim=1).cpu().tolist())\n",
    "    return preds_list\n",
    "# Predict labels for the test loader; `preds_list` is later written into the submission file.\n",
    "preds_list = model_predict(net, test_loader)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 16,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Load the sample submission CSV, set its 'label' column to the predictions,\n",
    "# and save the result without the index column.\n",
    "submission = pd.read_csv(\n",
    "    'C:/Users/Administrator/Desktop/Pytorch_BERT/data/test_a_sample_submit.csv'\n",
    ").assign(label=preds_list)\n",
    "submission.to_csv(\n",
    "    'C:/Users/Administrator/Desktop/Pytorch_BERT/output/submission.csv',\n",
    "    index=False,\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.9.23"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
